# Copyright (c) 2023 Copyright holder of the paper "Revisiting Image Classifier Training for Improved Certified Robust Defense against Adversarial Patches" submitted to TMLR for review

# All rights reserved.

import argparse
import numpy as np
import math
import random
import time
import os

import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import timm

from build import generate_masks
from utils import *

# Usage: python eval.py --dataset imagenet

datasets = ["imagenet", "cifar10", "cifar100", "imagenette"]

parser = argparse.ArgumentParser(description='ImageNet Clean Evaluation')
# `%(default)s` is expanded by argparse, so the help text can never drift out
# of sync with the real defaults.
# NOTE(review): --workers and --print-freq are parsed but main() in this file
# only forwards args.dataset to get_dataloaders — confirm whether they matter.
parser.add_argument('-j', '--workers', default=6, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('-p', '--print-freq', default=1000, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--seed', default=2022, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--dataset', metavar='dataset', default='imagenet',
                    choices=datasets)
# This module-level parse runs at import time. `parse_known_args` keeps the
# module importable under a foreign argv (e.g. a test runner) instead of
# exiting on unrecognised options; when run as a script, main() re-parses
# strictly with parse_args(), so unknown options are still rejected there.
args, _ = parser.parse_known_args()


# Per-channel normalisation statistics selected by normalize_test()
# (channel order presumably R, G, B — matches how the lists are indexed there).

# Commonly used ImageNet statistics.
ImageNet_Mean = [0.485, 0.456, 0.406]
ImageNet_Std = [0.229, 0.224, 0.225]

# Inception-style statistics: (x - 0.5) / 0.5 maps [0, 1] inputs to [-1, 1].
Inception_Mean = [0.5] * 3
Inception_Std = [0.5] * 3


def normalize_test(tensor, stats):
    """Standardise the first three channels of an image batch.

    ``stats == "inception"`` selects ``Inception_Mean``/``Inception_Std``;
    any other value falls back to ``ImageNet_Mean``/``ImageNet_Std``.
    Each of channels 0, 1 and 2 (dim 1) is transformed as
    ``(x - mean[c]) / std[c]`` and the three results are concatenated back
    along dim 1.

    Args:
        tensor: image batch indexed as ``tensor[:, c, :, :]`` — assumes an
            NCHW layout with at least three channels (TODO confirm with the
            data loaders).
        stats: name of the statistics set to use.

    Returns:
        A new tensor holding the normalised three channels.
    """
    use_inception = stats == "inception"
    means = Inception_Mean if use_inception else ImageNet_Mean
    stds = Inception_Std if use_inception else ImageNet_Std
    # Index with a list ([idx]) so each slice keeps its channel dimension.
    per_channel = [
        (tensor[:, [idx], :, :] - means[idx]) / stds[idx]
        for idx in range(3)
    ]
    return torch.cat(per_channel, dim=1)


def validate_clean(val_loader, classifier, stats='inception'):
    """Measure and print the clean (unperturbed) top-1 accuracy of a classifier.

    Every batch from ``val_loader`` is moved to the GPU, normalised with
    ``normalize_test(images, stats)`` and scored. One summary line with the
    average top-1 accuracy and the elapsed wall-clock time is printed at the
    end.

    Args:
        val_loader: iterable of batches; only ``batch[0]`` (images) and
            ``batch[1]`` (targets) are used.
        classifier: model mapping normalised images to class scores; it is
            switched to eval mode here.
        stats: ``'inception'`` selects the Inception mean/std, any other value
            the ImageNet mean/std (same rule as ``normalize_test``).

    Returns:
        ``top1.avg`` from the ``AverageMeter`` (new; callers that ignored the
        previous ``None`` return are unaffected).
    """
    classifier.eval()
    top1 = AverageMeter('Clean Acc@1', ':6.2f')
    stats_label = 'Inception' if stats == 'inception' else 'ImageNet'
    eval_start_time = time.perf_counter()
    with torch.no_grad():
        for data in val_loader:
            clean_images, target = data[0].cuda(), data[1].cuda()
            output = classifier(normalize_test(clean_images, stats))
            # Only top-1 is reported, so don't compute (and discard) top-5.
            acc1 = accuracy(output, target, topk=(1,))[0]
            top1.update(acc1[0], target.size(0))
    # CUDA kernels run asynchronously; wait for queued work before reading
    # the clock so the reported time covers the whole evaluation.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - eval_start_time
    print('{stats} mean and std: Clean Acc@1 {top1.avg:.3f} Time {Time:.3f} secs'.format(
        stats=stats_label, top1=top1, Time=elapsed))
    return top1.avg


def main():
    """Print clean validation accuracy for several pretrained timm backbones.

    Each backbone is evaluated twice on the validation split: once with the
    Inception normalisation statistics and once with the ImageNet ones.
    """
    cudnn.benchmark = True
    cli = parser.parse_args()
    if cli.seed is not None:
        # Seed every RNG the pipeline may touch, in a fixed order.
        for seed_fn in (torch.manual_seed, torch.cuda.manual_seed,
                        np.random.seed, random.seed):
            seed_fn(cli.seed)

    print("Dataset:", cli.dataset)
    print("\n")

    # Only the validation loader is used for clean evaluation.
    _, val_loader = get_dataloaders(cli.dataset)
    num_classes = get_num_classes(cli.dataset)

    # (label printed to the log, timm model identifier)
    backbones = (
        ("resnetv2", 'resnetv2_50x1_bit_distilled'),
        ("vit_base", 'vit_base_patch16_224'),
        ("convnext", 'convnext_tiny_in22ft1k'),
    )
    for label, timm_name in backbones:
        print("Pretrained model: ", label)
        classifier = timm.create_model(timm_name, pretrained=True)

        if num_classes != 1000:
            # NOTE(review): timm's reset_classifier installs a new head; unless
            # trained weights are loaded somewhere, accuracy on non-1000-class
            # datasets reflects an untrained head — confirm this is intended.
            classifier.reset_classifier(num_classes=num_classes)

        classifier.cuda()
        for stats_name in ('inception', 'imagenet'):
            validate_clean(val_loader, classifier, stats=stats_name)


if __name__ == '__main__':
    main()